import numpy as np
import matplotlib.pyplot as plt 

# Schedule configuration read by get_lr() and the plotting code.
max_steps: int = 100_000   # horizon, in steps, for the learning-rate plot
warmup_steps: int = 10     # length of the linear-warmup phase in get_lr()
factor: float = 1.0        # overall multiplier applied to the learning rate
model_size: int = 768      # model width; the rate scales with model_size ** -0.5


def get_lr(step):
    """Return the learning rate for training step ``step``.

    Implements the warmup-then-decay schedule from "Attention Is All You
    Need" (Vaswani et al., 2017), scaled by the module-level ``factor``::

        lr = factor * model_size**-0.5
             * min(step**-0.5, step * warmup_steps**-1.5)

    The second argument of ``min`` grows linearly with ``step`` and is the
    smaller one while ``step < warmup_steps``. After that the ``step**-0.5``
    term is smaller, so the rate peaks at ``step == warmup_steps`` and then
    decays.

    Args:
        step: Training step number; must be non-negative.

    Returns:
        The learning rate as a float. Step 0 returns 0.0, which is the value
        of the linear-warmup branch at zero (``step**-0.5`` is undefined
        there).

    Raises:
        ValueError: If ``step`` is negative.
    """
    if step < 0:
        # A negative base to a fractional power yields a complex number in
        # Python, which would make min() fail with an unhelpful TypeError.
        raise ValueError(f"step must be non-negative, got {step}")
    if step == 0:
        # 0 ** -0.5 raises ZeroDivisionError; the warmup branch is 0 here.
        return 0.0
    # Same left-to-right evaluation order as the one-line formula.
    scale = factor * model_size ** (-0.5)
    return scale * min(step ** (-0.5), step * warmup_steps ** (-1.5))

# Evaluate the schedule at every step from 1 through max_steps inclusive.
# Start at step 1 because the step ** -0.5 term of the schedule is undefined
# at step 0; range's stop is exclusive, hence the + 1.
x = list(range(1, max_steps + 1))
lr = [get_lr(i) for i in x]

plt.plot(x, lr)
plt.xlabel("step")
plt.ylabel("learning rate")
plt.title(
    f"LR schedule (warmup_steps={warmup_steps}, "
    f"model_size={model_size}, factor={factor})"
)
plt.show()